package com.carrot.uaa.infrastructure.redis;

import cn.hutool.core.util.StrUtil;
import com.carrot.uaa.constant.Oauth2Constants;
import com.carrot.uaa.constant.SecurityConstants;
import com.carrot.uaa.entity.OAuth2User;
import com.carrot.uaa.entity.RedisAuthorization;
import com.carrot.uaa.infrastructure.mapper.SysTenantOrderDetailMapper;
import com.carrot.uaa.infrastructure.service.SysUserLoginLogService;
import com.carrot.uaa.infrastructure.service.UserService;
import com.carrot.uaa.util.RedisUtils;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
import org.springframework.dao.DataRetrievalFailureException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.oauth2.core.*;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.jackson2.OAuth2AuthorizationServerJackson2Module;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import java.security.Principal;
import java.time.Instant;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

/**
 * Redis-backed {@link OAuth2AuthorizationService}.
 *
 * <p>Besides persisting {@link OAuth2Authorization}s through {@link RedisAuthorizationRepository},
 * this service maintains the "online token" records used to force tokens offline:
 * <ul>
 *   <li>for {@code refresh_token} and SMS-code grants it records the user's online access/refresh
 *       tokens (and, for SMS-code logins, updates the last login time/IP and writes a login log);</li>
 *   <li>for {@code client_credentials} grants it records the tenant's online access token;</li>
 *   <li>on the token revocation endpoint it delegates to {@code RedisTokenUtils#revokeToken}.</li>
 * </ul>
 */
@Service
@RequiredArgsConstructor
public class RedisAuthorizationService implements OAuth2AuthorizationService {

    /**
     * Lifetime of a stored authorization and of its access/refresh token index keys, in minutes (7 days).
     */
    private static final long AUTHORIZATION_TTL_MINUTES = 7L * 24 * 60;

    private final RegisteredClientRepository registeredClientRepository;
    private final RedisAuthorizationRepository oAuth2AuthorizationRepository;
    private final RedisOperator<Object> redisOperator;
    private final UserService userService;
    private final SysUserLoginLogService sysUserLoginLogService;
    private final SysTenantOrderDetailMapper sysTenantOrderDetailMapper;
    private final RedisTokenUtils redisTokenUtils;

    private static final ObjectMapper MAPPER = new ObjectMapper();

    static {
        // Configure (de)serialization once for the shared mapper.
        ClassLoader classLoader = RedisAuthorizationService.class.getClassLoader();
        // Modules provided by Spring Security.
        List<Module> modules = SecurityJackson2Modules.getModules(classLoader);
        MAPPER.registerModules(modules);
        // Module provided by Spring Authorization Server.
        MAPPER.registerModule(new OAuth2AuthorizationServerJackson2Module());
    }

    /**
     * Stores the authorization in Redis, replacing any entry with the same id, then updates the
     * online-token bookkeeping for the current request.
     *
     * <p>The request-dependent bookkeeping is skipped when no servlet request is bound to the
     * current thread (e.g. a call from a background job). The original code failed with an NPE in
     * that case, after the authorization had already been stored.
     *
     * @param authorization the authorization to store
     * @throws IllegalArgumentException if {@code authorization} is null
     */
    @Override
    public void save(OAuth2Authorization authorization) {
        Assert.notNull(authorization, "authorization cannot be null");
        RedisAuthorization entity = persist(authorization);

        if (!(RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes servletAttributes)) {
            // Not inside a servlet request: there is no grant type / endpoint to react to.
            return;
        }
        HttpServletRequest request = servletAttributes.getRequest();
        String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);

        Object principalAttribute = authorization.getAttributes().get(Principal.class.getName());
        if (principalAttribute instanceof Authentication authentication
                && authentication.getPrincipal() instanceof OAuth2User oAuth2User) {
            // NOTE(review): client_ip is taken from a request parameter, i.e. it is client-supplied
            // and can be spoofed — confirm this is only reachable from trusted callers.
            handleUserGrant(entity, oAuth2User, grantType, request.getParameter(Oauth2Constants.CLIENT_IP));
        }

        if (AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(grantType)) {
            // Client credentials grant: track the tenant's online token.
            saveOnLineToken(entity);
        }

        // Token revocation: take the revoked token offline.
        if (SecurityConstants.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI.equals(request.getRequestURI())) {
            String revokeToken = request.getParameter(Oauth2Constants.TOKEN);
            redisTokenUtils.revokeToken(revokeToken, entity.getRefreshTokenValue(), entity.getAccessTokenValue());
        }
    }

    /**
     * Replaces any stored authorization with the same id, stores the new one with a
     * {@link #AUTHORIZATION_TTL_MINUTES} timeout, and applies the same TTL to the access/refresh
     * token index keys.
     *
     * @param authorization the authorization to store
     * @return the entity that was written to Redis
     */
    private RedisAuthorization persist(OAuth2Authorization authorization) {
        oAuth2AuthorizationRepository.findById(authorization.getId())
                .map(RedisAuthorization::getId)
                .ifPresent(oAuth2AuthorizationRepository::deleteById);

        RedisAuthorization entity = toEntity(authorization);
        entity.setTimeout(AUTHORIZATION_TTL_MINUTES);
        oAuth2AuthorizationRepository.save(entity);

        // The index keys are expired explicitly with the same timeout as the entity.
        if (StrUtil.isNotBlank(entity.getAccessTokenValue())) {
            redisOperator.expire(RedisUtils.getAuthorizationAccessTokenIdxKey(entity.getAccessTokenValue()),
                    entity.getTimeout(), TimeUnit.MINUTES);
        }
        if (StrUtil.isNotBlank(entity.getRefreshTokenValue())) {
            redisOperator.expire(RedisUtils.getAuthorizationRefreshTokenIdxKey(entity.getRefreshTokenValue()),
                    entity.getTimeout(), TimeUnit.MINUTES);
        }
        return entity;
    }

    /**
     * Updates the user-facing bookkeeping for a user-bound authorization.
     *
     * @param entity     the stored authorization
     * @param oAuth2User the authenticated user
     * @param grantType  the {@code grant_type} request parameter (may be null)
     * @param clientIp   the {@code client_ip} request parameter (may be null)
     */
    private void handleUserGrant(RedisAuthorization entity, OAuth2User oAuth2User, String grantType, String clientIp) {
        boolean smsLogin = SecurityConstants.GRANT_TYPE_SMS_CODE.equals(grantType);
        if (smsLogin || AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(grantType)) {
            saveOnLineToken(entity, oAuth2User);
        }
        if (smsLogin) {
            // Record the last login time/IP and write a login log entry.
            userService.updateUserLoginTimeAndIp(oAuth2User.getUid(), clientIp, oAuth2User.getUserType());
            sysUserLoginLogService.save(oAuth2User, clientIp);
        }
    }

    /**
     * Deletes the stored authorization with the same id.
     *
     * @param authorization the authorization to remove
     * @throws IllegalArgumentException if {@code authorization} is null
     */
    @Override
    public void remove(OAuth2Authorization authorization) {
        Assert.notNull(authorization, "authorization cannot be null");
        oAuth2AuthorizationRepository.deleteById(authorization.getId());
    }

    /**
     * Looks up an authorization by id.
     *
     * @param id the authorization id
     * @return the authorization, or {@code null} if none is stored
     * @throws IllegalArgumentException if {@code id} is empty
     */
    @Override
    public OAuth2Authorization findById(String id) {
        Assert.hasText(id, "id cannot be empty");
        return oAuth2AuthorizationRepository.findById(id)
                .map(this::toObject).orElse(null);
    }

    /**
     * Looks up an authorization by one of its token values.
     *
     * <p>When {@code tokenType} is null, every token kind is tried in turn. For refresh tokens, a
     * token that is still unexpired but no longer recorded as online (i.e. it was forced offline)
     * is rejected.
     *
     * @param token     the token value
     * @param tokenType the token type, or {@code null} to search all kinds
     * @return the authorization, or {@code null} if none matches
     * @throws IllegalArgumentException      if {@code token} is empty
     * @throws OAuth2AuthenticationException with error code {@code account_exception} for a refresh
     *                                       token that has been forced offline
     */
    @Override
    public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) {
        Assert.hasText(token, "token cannot be empty");

        Optional<RedisAuthorization> result;

        if (tokenType == null) {
            result = oAuth2AuthorizationRepository.findByState(token)
                    .or(() -> oAuth2AuthorizationRepository.findByAuthorizationCodeValue(token))
                    .or(() -> oAuth2AuthorizationRepository.findByAccessTokenValue(token))
                    .or(() -> oAuth2AuthorizationRepository.findByOidcIdTokenValue(token))
                    .or(() -> oAuth2AuthorizationRepository.findByRefreshTokenValue(token))
                    .or(() -> oAuth2AuthorizationRepository.findByUserCodeValue(token))
                    .or(() -> oAuth2AuthorizationRepository.findByDeviceCodeValue(token));
        } else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
            result = oAuth2AuthorizationRepository.findByState(token);
        } else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) {
            result = oAuth2AuthorizationRepository.findByAuthorizationCodeValue(token);
        } else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) {
            result = oAuth2AuthorizationRepository.findByAccessTokenValue(token);
        } else if (OidcParameterNames.ID_TOKEN.equals(tokenType.getValue())) {
            result = oAuth2AuthorizationRepository.findByOidcIdTokenValue(token);
        } else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) {
            result = oAuth2AuthorizationRepository.findByRefreshTokenValue(token);
            // Check whether the refresh token has been forced offline.
            if (result.isPresent() && !redisOperator.hseKey(RedisUtils.getUserOnLineRefreshTokenKey(token))) {
                Instant refreshTokenExpiresAt = result.get().getRefreshTokenExpiresAt();
                if (refreshTokenExpiresAt != null && Instant.now().isBefore(refreshTokenExpiresAt)) {
                    throw new OAuth2AuthenticationException("account_exception");
                }
            }
        } else if (OAuth2ParameterNames.USER_CODE.equals(tokenType.getValue())) {
            result = oAuth2AuthorizationRepository.findByUserCodeValue(token);
        } else if (OAuth2ParameterNames.DEVICE_CODE.equals(tokenType.getValue())) {
            result = oAuth2AuthorizationRepository.findByDeviceCodeValue(token);
        } else {
            result = Optional.empty();
        }
        return result.map(this::toObject).orElse(null);
    }

    /**
     * Records the user's online access/refresh tokens for the authorization's client.
     *
     * <p>The user's other online tokens for the same client are removed first.
     *
     * @param entity     the stored authorization (expected to carry an access token)
     * @param oAuth2User the authenticated user
     */
    private void saveOnLineToken(RedisAuthorization entity, OAuth2User oAuth2User) {
        String refreshTokenValue = entity.getRefreshTokenValue();
        String accessTokenValue = entity.getAccessTokenValue();
        String registeredClientId = entity.getRegisteredClientId();
        long nowTime = Instant.now().getEpochSecond();
        String onLineTokenKey = RedisUtils.getUserOnLineTokenKey(oAuth2User.getUid());

        // Take the same user's other online tokens for the same client offline.
        redisTokenUtils.delOtherOnlineToken(onLineTokenKey, registeredClientId);

        Long onLineTokenKeyExp = null;
        // Online refresh token record.
        if (refreshTokenValue != null) {
            String onLineRefreshTokenKey = RedisUtils.getUserOnLineRefreshTokenKey(refreshTokenValue);
            Long refreshExp = entity.getRefreshTokenExpiresAt().getEpochSecond() - nowTime;
            redisTokenUtils.saveRefreshToken(onLineRefreshTokenKey, oAuth2User, refreshExp);
            onLineTokenKeyExp = refreshExp;
        }

        // Online access token record.
        String onLineAccessTokenKey = RedisUtils.getUserOnLineAccessTokenKey(accessTokenValue);
        Long accessExp = entity.getAccessTokenExpiresAt().getEpochSecond() - nowTime;
        redisTokenUtils.saveAccessToken(onLineAccessTokenKey, oAuth2User, accessExp);

        // Per-user online token map: lives as long as the longest-lived token it references.
        onLineTokenKeyExp = extendedExpiry(onLineTokenKey, onLineTokenKeyExp == null ? accessExp : onLineTokenKeyExp);
        Map<String, String> tokenMap = new HashMap<>(2);
        tokenMap.put(Oauth2Constants.REFRESH_TOKEN, refreshTokenValue);
        tokenMap.put(Oauth2Constants.ACCESS_TOKEN, accessTokenValue);
        redisTokenUtils.saveToken(onLineTokenKey, registeredClientId, tokenMap, onLineTokenKeyExp);
    }

    /**
     * Records the tenant's online access token for a {@code client_credentials} authorization.
     *
     * <p>The online record's lifetime is capped at the client secret's expiry when the secret
     * expires. A null expiry means the secret never expires; it previously caused an NPE. The
     * stored authorization itself is not modified.
     *
     * @param entity the stored authorization; its principal name is used as the client id
     */
    private void saveOnLineToken(RedisAuthorization entity) {
        String accessTokenValue = entity.getAccessTokenValue();
        Long nowTime = Instant.now().getEpochSecond();
        String clientId = entity.getPrincipalName();

        // Cap the online lifetime at the client secret expiry, if there is one.
        Instant accessTokenExpiresAt = entity.getAccessTokenExpiresAt();
        Instant clientSecretExpiresAt = resolveClientSecretExpiresAt(entity);
        if (clientSecretExpiresAt != null && accessTokenExpiresAt.isAfter(clientSecretExpiresAt)) {
            accessTokenExpiresAt = clientSecretExpiresAt;
        }

        // Resolve the tenant that owns this client.
        String tenantId = sysTenantOrderDetailMapper.selectTenantIdByClientId(clientId);
        String tenantOnLineTokenKey = RedisUtils.getTenantOnLineTokenKey(tenantId);

        // Take the other online tokens of the same client offline.
        redisTokenUtils.delOtherOnlineToken(tenantOnLineTokenKey, clientId);

        // Online access token record.
        String onLineAccessTokenKey = RedisUtils.getUserOnLineAccessTokenKey(accessTokenValue);
        Long accessExp = accessTokenExpiresAt.getEpochSecond() - nowTime;
        Map<String, String> value = new HashMap<>(2);
        value.put("tenantId", tenantId);
        value.put("clientId", clientId);
        redisTokenUtils.saveAccessToken(onLineAccessTokenKey, value, accessExp);

        // Per-tenant online token map.
        Long onLineTokenKeyExp = extendedExpiry(tenantOnLineTokenKey, accessExp);
        Map<String, String> tokenMap = new HashMap<>(2);
        tokenMap.put(Oauth2Constants.ACCESS_TOKEN, accessTokenValue);
        redisTokenUtils.saveToken(tenantOnLineTokenKey, clientId, tokenMap, onLineTokenKeyExp);
    }

    /**
     * Resolves the client secret expiry for the authorization's client.
     *
     * <p>The client authenticated in the current security context is preferred. Otherwise the
     * client is looked up by the authorization's registered client id.
     *
     * @param entity the stored authorization
     * @return the client secret expiry, or {@code null} if it does not expire or the client is unknown
     */
    private Instant resolveClientSecretExpiresAt(RedisAuthorization entity) {
        Authentication current = SecurityContextHolder.getContext().getAuthentication();
        RegisteredClient client = current instanceof OAuth2ClientAuthenticationToken clientToken
                ? clientToken.getRegisteredClient()
                : null;
        if (client == null) {
            client = registeredClientRepository.findById(entity.getRegisteredClientId());
        }
        return client == null ? null : client.getClientSecretExpiresAt();
    }

    /**
     * Returns the larger of {@code candidateSeconds} and the key's remaining TTL in seconds.
     *
     * <p>This keeps a shared online-token key from being shortened by a shorter-lived token. A null
     * TTL, or any TTL not greater than the candidate, yields the candidate.
     *
     * @param key              the Redis key whose remaining TTL is read
     * @param candidateSeconds the lifetime required by the token being recorded
     * @return the TTL to apply to {@code key}, in seconds
     */
    private Long extendedExpiry(String key, Long candidateSeconds) {
        Long remaining = redisOperator.getExpire(key, TimeUnit.SECONDS);
        return remaining != null && remaining > candidateSeconds ? remaining : candidateSeconds;
    }

    /**
     * Converts the stored Redis entity into the framework's {@link OAuth2Authorization}.
     *
     * @param entity the entity stored in Redis
     * @return the framework authorization
     * @throws DataRetrievalFailureException if the registered client no longer exists
     */
    private OAuth2Authorization toObject(RedisAuthorization entity) {
        RegisteredClient registeredClient = this.registeredClientRepository.findById(entity.getRegisteredClientId());
        if (registeredClient == null) {
            throw new DataRetrievalFailureException(
                    "The RegisteredClient with id '" + entity.getRegisteredClientId() + "' was not found in the RegisteredClientRepository.");
        }

        OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient)
                .id(entity.getId())
                .principalName(entity.getPrincipalName())
                .authorizationGrantType(resolveAuthorizationGrantType(entity.getAuthorizationGrantType()))
                .authorizedScopes(StringUtils.commaDelimitedListToSet(entity.getAuthorizedScopes()))
                .attributes(attributes -> attributes.putAll(parseMap(entity.getAttributes())));
        if (entity.getState() != null) {
            builder.attribute(OAuth2ParameterNames.STATE, entity.getState());
        }

        if (entity.getAuthorizationCodeValue() != null) {
            OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
                    entity.getAuthorizationCodeValue(),
                    entity.getAuthorizationCodeIssuedAt(),
                    entity.getAuthorizationCodeExpiresAt());
            builder.token(authorizationCode, metadata -> metadata.putAll(parseMap(entity.getAuthorizationCodeMetadata())));
        }

        if (entity.getAccessTokenValue() != null) {
            OAuth2AccessToken accessToken = new OAuth2AccessToken(
                    OAuth2AccessToken.TokenType.BEARER,
                    entity.getAccessTokenValue(),
                    entity.getAccessTokenIssuedAt(),
                    entity.getAccessTokenExpiresAt(),
                    StringUtils.commaDelimitedListToSet(entity.getAccessTokenScopes()));
            builder.token(accessToken, metadata -> metadata.putAll(parseMap(entity.getAccessTokenMetadata())));
        }

        if (entity.getRefreshTokenValue() != null) {
            OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
                    entity.getRefreshTokenValue(),
                    entity.getRefreshTokenIssuedAt(),
                    entity.getRefreshTokenExpiresAt());
            builder.token(refreshToken, metadata -> metadata.putAll(parseMap(entity.getRefreshTokenMetadata())));
        }

        if (entity.getOidcIdTokenValue() != null) {
            OidcIdToken idToken = new OidcIdToken(
                    entity.getOidcIdTokenValue(),
                    entity.getOidcIdTokenIssuedAt(),
                    entity.getOidcIdTokenExpiresAt(),
                    parseMap(entity.getOidcIdTokenClaims()));
            builder.token(idToken, metadata -> metadata.putAll(parseMap(entity.getOidcIdTokenMetadata())));
        }

        if (entity.getUserCodeValue() != null) {
            OAuth2UserCode userCode = new OAuth2UserCode(
                    entity.getUserCodeValue(),
                    entity.getUserCodeIssuedAt(),
                    entity.getUserCodeExpiresAt());
            builder.token(userCode, metadata -> metadata.putAll(parseMap(entity.getUserCodeMetadata())));
        }

        if (entity.getDeviceCodeValue() != null) {
            OAuth2DeviceCode deviceCode = new OAuth2DeviceCode(
                    entity.getDeviceCodeValue(),
                    entity.getDeviceCodeIssuedAt(),
                    entity.getDeviceCodeExpiresAt());
            builder.token(deviceCode, metadata -> metadata.putAll(parseMap(entity.getDeviceCodeMetadata())));
        }

        return builder.build();
    }

    /**
     * Converts the framework's {@link OAuth2Authorization} into the entity stored in Redis.
     *
     * @param authorization the framework authorization
     * @return the entity to store in Redis
     */
    private RedisAuthorization toEntity(OAuth2Authorization authorization) {
        RedisAuthorization entity = new RedisAuthorization();
        entity.setId(authorization.getId());
        entity.setRegisteredClientId(authorization.getRegisteredClientId());
        entity.setPrincipalName(authorization.getPrincipalName());
        entity.setAuthorizationGrantType(authorization.getAuthorizationGrantType().getValue());
        entity.setAuthorizedScopes(StringUtils.collectionToDelimitedString(authorization.getAuthorizedScopes(), ","));
        entity.setAttributes(writeMap(authorization.getAttributes()));
        entity.setState(authorization.getAttribute(OAuth2ParameterNames.STATE));

        OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
                authorization.getToken(OAuth2AuthorizationCode.class);
        setTokenValues(
                authorizationCode,
                entity::setAuthorizationCodeValue,
                entity::setAuthorizationCodeIssuedAt,
                entity::setAuthorizationCodeExpiresAt,
                entity::setAuthorizationCodeMetadata
        );

        OAuth2Authorization.Token<OAuth2AccessToken> accessToken =
                authorization.getToken(OAuth2AccessToken.class);
        setTokenValues(
                accessToken,
                entity::setAccessTokenValue,
                entity::setAccessTokenIssuedAt,
                entity::setAccessTokenExpiresAt,
                entity::setAccessTokenMetadata
        );
        if (accessToken != null && accessToken.getToken().getScopes() != null) {
            entity.setAccessTokenScopes(StringUtils.collectionToDelimitedString(accessToken.getToken().getScopes(), ","));
        }

        OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken =
                authorization.getToken(OAuth2RefreshToken.class);
        setTokenValues(
                refreshToken,
                entity::setRefreshTokenValue,
                entity::setRefreshTokenIssuedAt,
                entity::setRefreshTokenExpiresAt,
                entity::setRefreshTokenMetadata
        );

        OAuth2Authorization.Token<OidcIdToken> oidcIdToken =
                authorization.getToken(OidcIdToken.class);
        setTokenValues(
                oidcIdToken,
                entity::setOidcIdTokenValue,
                entity::setOidcIdTokenIssuedAt,
                entity::setOidcIdTokenExpiresAt,
                entity::setOidcIdTokenMetadata
        );
        if (oidcIdToken != null) {
            entity.setOidcIdTokenClaims(writeMap(oidcIdToken.getClaims()));
        }

        OAuth2Authorization.Token<OAuth2UserCode> userCode =
                authorization.getToken(OAuth2UserCode.class);
        setTokenValues(
                userCode,
                entity::setUserCodeValue,
                entity::setUserCodeIssuedAt,
                entity::setUserCodeExpiresAt,
                entity::setUserCodeMetadata
        );

        OAuth2Authorization.Token<OAuth2DeviceCode> deviceCode =
                authorization.getToken(OAuth2DeviceCode.class);
        setTokenValues(
                deviceCode,
                entity::setDeviceCodeValue,
                entity::setDeviceCodeIssuedAt,
                entity::setDeviceCodeExpiresAt,
                entity::setDeviceCodeMetadata
        );

        return entity;
    }

    /**
     * Copies a token's value, timestamps and metadata into the given setters. Does nothing when
     * {@code token} is null.
     *
     * @param token              the token, may be null
     * @param tokenValueConsumer setter for the token value
     * @param issuedAtConsumer   setter for the issued-at instant
     * @param expiresAtConsumer  setter for the expires-at instant
     * @param metadataConsumer   setter for the JSON-serialized metadata
     */
    private void setTokenValues(
            OAuth2Authorization.Token<?> token,
            Consumer<String> tokenValueConsumer,
            Consumer<Instant> issuedAtConsumer,
            Consumer<Instant> expiresAtConsumer,
            Consumer<String> metadataConsumer) {
        if (token != null) {
            OAuth2Token oAuth2Token = token.getToken();
            tokenValueConsumer.accept(oAuth2Token.getTokenValue());
            issuedAtConsumer.accept(oAuth2Token.getIssuedAt());
            expiresAtConsumer.accept(oAuth2Token.getExpiresAt());
            metadataConsumer.accept(writeMap(token.getMetadata()));
        }
    }

    /**
     * Maps a stored grant type string to an {@link AuthorizationGrantType}.
     *
     * @param authorizationGrantType the grant type value stored in Redis
     * @return the matching built-in constant, or a new instance for custom grant types
     */
    private static AuthorizationGrantType resolveAuthorizationGrantType(String authorizationGrantType) {
        if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(authorizationGrantType)) {
            return AuthorizationGrantType.AUTHORIZATION_CODE;
        } else if (AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(authorizationGrantType)) {
            return AuthorizationGrantType.CLIENT_CREDENTIALS;
        } else if (AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(authorizationGrantType)) {
            return AuthorizationGrantType.REFRESH_TOKEN;
        } else if (AuthorizationGrantType.DEVICE_CODE.getValue().equals(authorizationGrantType)) {
            return AuthorizationGrantType.DEVICE_CODE;
        }
        // Custom authorization grant type
        return new AuthorizationGrantType(authorizationGrantType);
    }

    /**
     * Deserializes a JSON string into a map using the security-aware mapper.
     *
     * @param data JSON string
     * @return the deserialized map
     * @throws IllegalArgumentException if the JSON cannot be parsed
     */
    private Map<String, Object> parseMap(String data) {
        try {
            return MAPPER.readValue(data, new TypeReference<>() {
            });
        } catch (Exception ex) {
            throw new IllegalArgumentException(ex.getMessage(), ex);
        }
    }

    /**
     * Serializes a map into a JSON string using the security-aware mapper.
     *
     * @param metadata the map to serialize
     * @return the JSON string
     * @throws IllegalArgumentException if the map cannot be serialized
     */
    private String writeMap(Map<String, Object> metadata) {
        try {
            return MAPPER.writeValueAsString(metadata);
        } catch (Exception ex) {
            throw new IllegalArgumentException(ex.getMessage(), ex);
        }
    }

}
